{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8df77d7f",
   "metadata": {},
   "source": [
    "# Lesson 3: Object Detection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86d1b37f",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fff6e4; padding:15px; border-width:3px; border-color:#f5ecda; border-style:solid; border-radius:6px\"> ⏳ <b>Note <code>(Kernel Starting)</code>:</b> This notebook takes about 30 seconds to be ready to use. You may start and watch the video while you wait.</p>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ff2703b",
   "metadata": {},
   "source": [
    "* In this classroom, the libraries have already been installed for you.\n",
    "* If you would like to run this code on your own machine, you need to install the following:\n",
    "    ```\n",
    "    !pip install -q comet_ml transformers ultralytics torch\n",
    "    ```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f33b83e5",
   "metadata": {},
   "source": [
    "### Set up Comet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29573cec-e264-41d6-b0b0-6f94d28feecb",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "import comet_ml"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e527f77",
   "metadata": {},
   "source": [
    "Learn more about [Comet](https://www.comet.com/site/?utm_source=dlai&utm_medium=course&utm_campaign=prompt_engineering_for_vision_models&utm_content=dlai_L3)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f5edef5-2a8e-42b6-91d3-28b1da97e765",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "comet_ml.init(anonymous=True, project_name=\"3: OWL-ViT + SAM\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e9678e1-4732-4b4e-b8fc-481d5b993afa",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "exp = comet_ml.Experiment()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc143578",
   "metadata": {},
   "source": [
    "### Load the image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43c4aafd-55dc-4fb1-981a-a8f984ad3f36",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# To display the image\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d985060-1b7c-48d0-afd6-367b3bde7e74",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "logged_artifact = exp.get_artifact(\"L3-data\", \"anmorgan24\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ef86770",
   "metadata": {},
   "source": [
    ">Note: the images referenced in this notebook have already been uploaded to the Jupyter directory, in this classroom, for your convenience. For further details, please refer to the **Appendix** section located at the end of the lessons."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51236f01-5e57-4b3d-a3da-3bdb31af9a77",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "local_artifact = logged_artifact.download(\"./\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01f65a28-3974-4b5e-8182-44537aef0a32",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Display the image\n",
    "raw_image = Image.open(\"L3_data/dogs.jpg\")\n",
    "raw_image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "019139e8-11d1-4b37-b258-9eafffe2f5fd",
   "metadata": {},
   "source": [
    "### Get bounding boxes with OWL-ViT object detection model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb2dab86",
   "metadata": {},
   "source": [
    ">Note: `pipeline` comes from the `transformers` library, which is already installed for you in this classroom."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c886b6d-dc91-4f2b-93e0-5a3178df04f2",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from transformers import pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d80f5ee-6b24-47f1-8f65-74dafe0de335",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "OWL_checkpoint = \"./models/google/owlvit-base-patch32\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "120a550b",
   "metadata": {},
   "source": [
    "Learn more about [google/owlvit-base-patch32](https://huggingface.co/google/owlvit-base-patch32)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0047c123",
   "metadata": {},
   "source": [
    "* Build the pipeline for the detector model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4cc2939-79d6-4f0d-8d7d-c462c5e27e55",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Load the model\n",
    "detector = pipeline(\n",
    "    model=OWL_checkpoint,\n",
    "    task=\"zero-shot-object-detection\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a7a3cc8-697a-4dab-bfaa-9e9428dd39e0",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# What you want to identify in the image\n",
    "text_prompt = \"dog\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00737c4e-4a35-4e13-8c61-cddc5a239c83",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "output = detector(\n",
    "    raw_image,\n",
    "    candidate_labels=[text_prompt]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "321b82fe-a74a-4479-9542-7d4ffe0982d3",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Print the output to identify the bounding boxes detected\n",
    "output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b382b6a1",
   "metadata": {},
   "source": [
    "* Use the helper functions from **utils** to draw the bounding boxes on top of the image."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7074fe92",
   "metadata": {},
   "source": [
    ">Note: `utils` is an additional file containing helper functions that have already been written for you to use in this classroom. \n",
    "For further details, please refer to the **Appendix** section located at the end of the lessons."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c26cea3b-65fd-412e-8ce6-a9c92c0ff1a3",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from utils import preprocess_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4dc1018-df22-4237-9739-594003c7ea2b",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "input_scores, input_labels, input_boxes = preprocess_outputs(output)"
   ]
  },
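  {
   "cell_type": "markdown",
   "id": "b7e1c0a2",
   "metadata": {},
   "source": [
    "The `preprocess_outputs` helper lives in `utils` and is not shown in this notebook. As a rough sketch only, assuming the detector returns a list of dicts with `score`, `label`, and `box` keys (the format produced by the Hugging Face zero-shot object detection pipeline), an unpacking step could look like the code below. The function name `unpack_detections` and the sample values are hypothetical. Wrapping the boxes in an outer list is an assumption made to mirror how this notebook indexes `input_boxes[0]`; the actual helper may differ.\n",
    "\n",
    "```python\n",
    "def unpack_detections(detections):\n",
    "    # Split a list of detection dicts into parallel lists\n",
    "    scores, labels, boxes = [], [], []\n",
    "    for det in detections:\n",
    "        box = det['box']\n",
    "        scores.append(det['score'])\n",
    "        labels.append(det['label'])\n",
    "        # Flatten the box dict into [xmin, ymin, xmax, ymax]\n",
    "        boxes.append([box['xmin'], box['ymin'], box['xmax'], box['ymax']])\n",
    "    # Assumption: boxes are nested one level deeper, as in input_boxes[0]\n",
    "    return scores, labels, [boxes]\n",
    "\n",
    "# Hypothetical detections, shaped like the pipeline output\n",
    "sample = [\n",
    "    {'score': 0.9, 'label': 'dog', 'box': {'xmin': 10, 'ymin': 20, 'xmax': 110, 'ymax': 220}},\n",
    "    {'score': 0.4, 'label': 'dog', 'box': {'xmin': 5, 'ymin': 6, 'xmax': 50, 'ymax': 60}},\n",
    "]\n",
    "scores, labels, boxes = unpack_detections(sample)\n",
    "print(boxes[0])\n",
    "```\n",
    "\n",
    "Keeping scores, labels, and boxes in parallel lists makes it easy to pass just the boxes on to a segmentation model while still being able to display the labels and scores."
   ]
  },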
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df173e1c-0e61-47a6-8ffd-a4729b6770e3",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from utils import show_boxes_and_labels_on_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a715c552-0c39-4766-96d5-9125ffaf8a60",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "# Show the image with the bounding boxes\n",
    "show_boxes_and_labels_on_image(\n",
    "    raw_image,\n",
    "    input_boxes[0],\n",
    "    input_labels,\n",
    "    input_scores\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd750f67-a674-44b2-9c2e-d7d734ddf057",
   "metadata": {},
   "source": [
    "### Get segmentation masks using Mobile SAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "715a4509-ffdd-4792-a6cb-0cbd1b469ee3",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Load the SAM model from the imported ultralytics library\n",
    "from ultralytics import SAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35dcdcb9-2a97-486c-a4c3-9299f166bf12",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "SAM_version = \"mobile_sam.pt\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a5f09dd",
   "metadata": {},
   "source": [
    "Learn more about [mobile_sam.pt](https://docs.ultralytics.com/models/mobile-sam/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c536491-2a1d-43c0-b800-db60611f3e77",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "model = SAM(SAM_version)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32055e82",
   "metadata": {},
   "source": [
    "* Use NumPy to create an array of positive labels, one per detected bounding box."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97ebfce2-30ef-41c3-a90c-ba8837bbe9bf",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6924c0f1-343b-4e5a-846d-ec9b3071e561",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Create an array of positive labels with the same length as the number of predictions generated above\n",
    "labels = np.repeat(1, len(output))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a026253-1288-4c5a-8d40-59c159a3e9d5",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Print the labels\n",
    "labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d64130f-25fb-4408-84d6-34842fad1adb",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "result = model.predict(\n",
    "    raw_image,\n",
    "    bboxes=input_boxes[0],\n",
    "    labels=labels\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "912d3de5-46ed-4a54-b611-e75061d2bb3a",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a90de4c0-d0de-44a1-be51-335b30f7ad59",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "masks = result[0].masks.data\n",
    "masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e7ed1bf-aa03-44cf-a3b3-6535d155ebc5",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from utils import show_masks_on_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ba93891-f826-4c8d-9447-6ac061141ada",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Visualize the masks\n",
    "show_masks_on_image(\n",
    "    raw_image,\n",
    "    masks\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9517b557",
   "metadata": {},
   "source": [
    ">Note: The results obtained from running this notebook may vary slightly from those demonstrated by the instructor in the video."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c5ff161-f446-4aba-b4c1-d2163698289f",
   "metadata": {},
   "source": [
    "### Image Editing: blur out faces"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ae5d111",
   "metadata": {},
   "source": [
    "* Load the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "147e28c3-5290-48c3-80b0-d56a15fd8bdf",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f13d6309-2a05-4a85-9ed3-881cd0d95451",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "image_path = \"L3_data/people.jpg\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6de06864-3e4b-4bc5-8c43-a2a424067f24",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "raw_image = Image.open(image_path)\n",
    "raw_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bdd0fca-fd4b-4113-8324-043e662c426e",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "raw_image.size"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d5a571a",
   "metadata": {},
   "source": [
    "* Resize the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99a0d63b-38cc-434c-aca9-21fe34f8cdb2",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Compute the scale factor that brings the width to 600 px while keeping the aspect ratio\n",
    "mywidth = 600\n",
    "wpercent = mywidth / float(raw_image.size[0])\n",
    "wpercent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12e95ff6-4cc2-4969-be4a-984e1d8972a2",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# Apply the same scale factor to the height to get the new height in pixels\n",
    "hsize = int( float(raw_image.size[1]) * wpercent )\n",
    "hsize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8c70705-1b17-4ff8-b879-c83a0c737816",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Resize the image\n",
    "raw_image = raw_image.resize((mywidth, hsize))\n",
    "raw_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b4d30d1-93bd-4989-97e6-8e98e4e9317a",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "raw_image.size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "950ea037-4b8d-4566-9840-fca138475d1d",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Save the resized image\n",
    "image_path_resized = \"people_resized.jpg\"\n",
    "raw_image.save(image_path_resized)"
   ]
  },
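  {
   "cell_type": "markdown",
   "id": "c3d9f4e1",
   "metadata": {},
   "source": [
    "The resize steps above can be folded into one small function. This is a minimal sketch that only does the arithmetic: the name `size_for_width` is hypothetical, and it assumes `size` is a `(width, height)` pair like the one returned by `raw_image.size`.\n",
    "\n",
    "```python\n",
    "def size_for_width(size, target_width):\n",
    "    # size is a (width, height) pair, as returned by PIL's Image.size\n",
    "    width, height = size\n",
    "    # Scale factor that maps the current width onto the target width\n",
    "    scale = target_width / float(width)\n",
    "    # Apply the same factor to the height so the aspect ratio is preserved\n",
    "    return target_width, int(height * scale)\n",
    "\n",
    "new_size = size_for_width((1200, 800), 600)\n",
    "print(new_size)\n",
    "```\n",
    "\n",
    "The returned pair can be passed straight to `Image.resize`. Note that `int` truncates, so the new height can be up to one pixel short of the exact scaled value."
   ]
  },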
  {
   "cell_type": "markdown",
   "id": "05b8d963-6e65-4bee-9b6a-858a91759397",
   "metadata": {},
   "source": [
    "### Detect faces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5957593-af98-4b2c-875c-f0c3a0d87310",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "candidate_labels = [\"human face\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f5d2b3d-154e-4e66-8f3f-d9c51e350ed8",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Define a new Comet experiment for this new pipeline\n",
    "exp = comet_ml.Experiment()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa0ddd2c-689f-412a-ac34-02ffb0ca4a11",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Log raw image to the experiment\n",
    "_ = exp.log_image(\n",
    "    raw_image,\n",
    "    name = \"Raw image\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f3d6837-2b26-448c-a759-dc78ec41690a",
   "metadata": {},
   "source": [
    "* Create bounding boxes with OWL-ViT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c17b064b-45ef-4fed-9355-0c757bfa4b8b",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Apply detector model to the raw image\n",
    "output = detector(\n",
    "    raw_image,\n",
    "    candidate_labels=candidate_labels\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13ec3c98-ac31-444e-8e07-8a5aaee6fae7",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "input_scores, input_labels, input_boxes = preprocess_outputs(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dae430e-fbbf-43c1-b375-143ddbefcf45",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Print values of the bounding box coordinates identified\n",
    "input_boxes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88a77851-67f4-459a-a970-edd6ae1cd75f",
   "metadata": {},
   "source": [
    "#### Log the images and bounding boxes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "707708fc-d786-49e7-8c17-29f9cef7d8a6",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "metadata = {\n",
    "    \"OWL prompt\": candidate_labels,\n",
    "    \"SAM version\": SAM_version,\n",
    "    \"OWL version\": OWL_checkpoint\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bb17873-319d-4e59-8bd6-1dc538250211",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from utils import make_bbox_annots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33450381-477c-429a-81ff-47826df5a5bf",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "annotations = make_bbox_annots(\n",
    "    input_scores,\n",
    "    input_labels,\n",
    "    input_boxes,\n",
    "    metadata\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fc1832f-4664-4a52-a850-11cacbd02360",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "_ = exp.log_image(\n",
    "    raw_image,\n",
    "    annotations=annotations,\n",
    "    metadata=metadata,\n",
    "    name=\"OWL output\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "788de380-6ca9-4f15-9b6d-a0f4abdf8a41",
   "metadata": {},
   "source": [
    "### Segmentation masks using SAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "258417b7-2ca5-4260-9a9e-dc75685412ce",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "result = model.predict(\n",
    "    image_path_resized,\n",
    "    bboxes=input_boxes[0],\n",
    "    labels=np.repeat(1, len(input_boxes[0]))\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "023b770d-4066-4525-9cca-4a64fb854c43",
   "metadata": {},
   "source": [
    "### Blur entire image first"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f13cb88-0d60-4eac-99d6-2da3b10807ce",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from PIL.ImageFilter import GaussianBlur"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41ecce2e-619a-4fb4-b8a4-eaac99ffc5c8",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "blurred_img = raw_image.filter(GaussianBlur(radius=5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d49be80",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "blurred_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db8fdee9-5982-4f94-bbb2-18e79de00143",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "masks = result[0].masks.data.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "684fafa9-0a76-4200-8a29-fdcfad821e36",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Create an array of zeroes of the same shape as our image mask\n",
    "total_mask = np.zeros(masks[0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edf8b958",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Add each output mask to the total_mask\n",
    "for mask in masks:\n",
    "    total_mask = np.add(total_mask, mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "812df370-3d51-44e2-8366-d3d68f15cf28",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "# Where there is any value other than zero (where any masks exist), show the blurred image\n",
    "# Else, show the original (unblurred) image\n",
    "output = np.where(\n",
    "    np.expand_dims(total_mask != 0, axis=2),\n",
    "    blurred_img,\n",
    "    raw_image\n",
    ")"
   ]
  },
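  {
   "cell_type": "markdown",
   "id": "d5a8e2b6",
   "metadata": {},
   "source": [
    "To see what the mask union and `np.where` steps do, here is a toy sketch on tiny arrays. All values are made up for illustration: two 2x3 masks stand in for the SAM output, and two constant arrays stand in for the original and blurred images. `masks.sum(axis=0)` is used here as a compact equivalent of the accumulation loop, and `[..., None]` adds the trailing channel axis in the same way as `np.expand_dims(..., axis=2)`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Two toy 2x3 masks standing in for SAM outputs (1 = object pixel)\n",
    "masks = np.array([\n",
    "    [[1, 0, 0], [0, 0, 0]],\n",
    "    [[0, 0, 0], [0, 1, 1]],\n",
    "], dtype=np.float32)\n",
    "\n",
    "# Union of all masks: nonzero wherever any mask covers the pixel\n",
    "total_mask = masks.sum(axis=0)\n",
    "\n",
    "# Stand-ins for the original and blurred RGB images (height, width, channels)\n",
    "original = np.full((2, 3, 3), 200, dtype=np.uint8)\n",
    "blurred = np.full((2, 3, 3), 50, dtype=np.uint8)\n",
    "\n",
    "# Take the blurred pixel where the mask is set, the original pixel elsewhere\n",
    "composite = np.where(total_mask[..., None] != 0, blurred, original)\n",
    "print(composite[..., 0])\n",
    "```\n",
    "\n",
    "The mask has shape `(height, width)` while the images have shape `(height, width, 3)`, so the extra axis is what lets NumPy broadcast one mask value across all three color channels."
   ]
  },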
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42fa9908-ceaf-4c52-8e87-d15daa0aeafb",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9bf0f05-ba1a-47c2-a7a1-be0d3928ad76",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Display the image with the faces blurred\n",
    "plt.imshow(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4638f0e3",
   "metadata": {},
   "source": [
    "* Log this image in the **Comet** platform."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1af3bfab-2a10-4019-b420-87c0d791f487",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "metadata = {\n",
    "    \"OWL prompt\": candidate_labels,\n",
    "    \"SAM version\": SAM_version,\n",
    "    \"OWL version\": OWL_checkpoint\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebc12097-032b-42af-a0ab-982f052fb8e7",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "_ = exp.log_image(\n",
    "    output,\n",
    "    name=\"Blurred masks\",\n",
    "    metadata=metadata,\n",
    "    annotations=None\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85f18ad1-2785-4b1c-b1a1-371260de8e7d",
   "metadata": {},
   "source": [
    "### Blur just faces of those not wearing sunglasses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6191d84-003e-4cc9-b2f8-7b5fd73419c8",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# New label\n",
    "candidate_labels = [\"a person without sunglasses\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5538c58b",
   "metadata": {},
   "source": [
    "* Re-run the pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8824b8a-8b4a-4746-9448-c4a68f248643",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "exp = comet_ml.Experiment()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99fd21ba-c3de-467a-9907-bbbdc4df98ef",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "_ = exp.log_image(raw_image, name=\"Raw image\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "363ce773-2bc5-4200-843a-1fc4672098f5",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "output = detector(raw_image, candidate_labels=candidate_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61da9c1f-a5d0-4568-82f9-ec61396f3602",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "input_scores, input_labels, input_boxes = preprocess_outputs(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e34f96bf-3702-413e-a93d-8e75611a94a9",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Print the bounding box coordinates\n",
    "input_boxes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "136b60fc",
   "metadata": {},
   "source": [
    "* Explore what is happening in the **Comet** platform."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67fe8524-0771-4006-857b-880e66becec2",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from utils import make_bbox_annots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b733165-2c43-4f32-ab83-e377d618739d",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "metadata = {\n",
    "    \"OWL prompt\": candidate_labels,\n",
    "    \"SAM version\": SAM_version,\n",
    "    \"OWL version\": OWL_checkpoint,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86860adc-bac5-4ccc-b1f7-6172d006bde4",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "annotations = make_bbox_annots(\n",
    "    input_scores,\n",
    "    input_labels,\n",
    "    input_boxes,\n",
    "    metadata\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82a8c342-b343-4a23-ad29-31d0be1ec8c7",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "_ = exp.log_image(\n",
    "    raw_image,\n",
    "    annotations=annotations,\n",
    "    metadata=metadata,\n",
    "    name=\"OWL output no sunglasses\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88a73b5f-9ec4-4c07-bf57-5fc30608a3d2",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "result = model.predict(\n",
    "    image_path_resized,\n",
    "    bboxes=input_boxes[0],\n",
    "    labels=np.repeat(1, len(input_boxes[0]))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9902cc39-ed82-4c9d-89c6-6795db2ab15e",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "from PIL.ImageFilter import GaussianBlur\n",
    "blurred_img = raw_image.filter(GaussianBlur(radius=5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "639f3f04-de9b-469c-939f-c42f98b5bf4b",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "masks = result[0].masks.data.cpu().numpy()\n",
    "\n",
    "total_mask = np.zeros(masks[0].shape)\n",
    "for mask in masks:\n",
    "    total_mask = np.add(total_mask, mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6689e4fe-c03a-4e85-bfec-f57b47162cec",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "# Blur the masked regions and display the result\n",
    "output = np.where(\n",
    "    np.expand_dims(total_mask != 0, axis=2),\n",
    "    blurred_img,\n",
    "    raw_image\n",
    ")\n",
    "plt.imshow(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ae623b0",
   "metadata": {},
   "source": [
    "* Analyze results in the **Comet** platform."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f034f7a-4037-4df9-995c-57cb6ce2b84a",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "metadata = {\n",
    "    \"OWL prompt\": candidate_labels,\n",
    "    \"SAM version\": SAM_version,\n",
    "    \"OWL version\": OWL_checkpoint,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "627a3715-db57-4425-b449-3bde7e028ccb",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "_ = exp.log_image(\n",
    "    output,\n",
    "    name=\"Blurred masks no sunglasses\",\n",
    "    metadata=metadata,\n",
    "    annotations=None\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9ed24b4",
   "metadata": {},
   "source": [
    "### Try it yourself!\n",
    "Try the image editing pipeline with the following images."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7eef6d34",
   "metadata": {},
   "source": [
    ">Note: the images referenced in this notebook have already been uploaded to the Jupyter directory, in this classroom, for your convenience. For further details, please refer to the **Appendix** section located at the end of the lessons."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d23168ff-3a5c-4fff-864c-abc46669fda3",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "cafe_img = Image.open(\"L3_data/cafe.jpg\")\n",
    "cafe_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dcedae9-7526-4c12-9c71-fe6582c8f2bf",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "crosswalk_img = Image.open(\"L3_data/crosswalk.jpg\")\n",
    "crosswalk_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77335554-7ec3-44c9-a583-efe644be4975",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "metro_img = Image.open(\"L3_data/metro.jpg\")\n",
    "metro_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf19de66-d5ac-4c00-9c75-34055e8fa1ad",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "friends_img = Image.open(\"L3_data/friends.jpg\")\n",
    "friends_img"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9941707",
   "metadata": {},
   "source": [
    "### Additional Resources\n",
    "\n",
    "* For more on how to use [Comet](https://www.comet.com/site/?utm_source=dlai&utm_medium=course&utm_campaign=prompt_engineering_for_vision_models&utm_content=dlai_L3) for experiment tracking, check out this [Quickstart Guide](https://colab.research.google.com/drive/1jj9BgsFApkqnpPMLCHSDH-5MoL_bjvYq?usp=sharing) and the [Comet Docs](https://www.comet.com/docs/v2/?utm_source=dlai&utm_medium=course&utm_campaign=prompt_engineering_for_vision_models&utm_content=dlai_L3).\n",
    "* This course is based on two blog articles from Comet. Explore them for more on how to use newer versions of Stable Diffusion in this pipeline, additional tricks to improve your inpainting results, and a breakdown of the pipeline architecture:\n",
    "  * [SAM + Stable Diffusion for Text-to-Image Inpainting](https://www.comet.com/site/blog/sam-stable-diffusion-for-text-to-image-inpainting/?utm_source=dlai&utm_medium=course&utm_campaign=prompt_engineering_for_vision_models&utm_content=dlai_L3)\n",
    "  * [Image Inpainting for SDXL 1.0 Base Model + Refiner](https://www.comet.com/site/blog/image-inpainting-for-sdxl-1-0-base-refiner/?utm_source=dlai&utm_medium=course&utm_campaign=prompt_engineering_for_vision_models&utm_content=dlai_L3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
